package ffte.core.streamed

import spinal.core._
import spinal.lib._

import ffte.types.VecComplex
import ffte.types.FixComplex
import ffte.property.{getFFTIW,getFFTOW,FFTIOConfig,getTwiddleResolution}
import ffte.algorithm.{circularOrder,FFTMisc}
import ffte.algorithm.FFTGen.{Method,wpbFFT}
import ffte.core.samewidthed
import ffte.misc

/**
 * Per-sample coefficient multiplier used by the streamed WPB FFT.
 *
 * Each accepted input sample is multiplied by a constant complex coefficient
 * taken from a ROM and the product is cut back to `dW` bits.
 *
 * ROM contents:
 *  - `p.N-1` entries of width `tW`;
 *  - entry `k` holds `wf(p.p.first.store_tab(k))`, where `wf` is the first
 *    element returned by `circularOrder(p.N).wfft(tW)`.
 *
 * Sequencing:
 *  - `cnt` advances on every `io.d.fire` and is cleared when the fired sample
 *    carries `last`; it is used directly as the ROM address.
 *
 * Pipeline depth (see `method` / `delay`):
 *  - `p.N > 5`: the sample is registered, the ROM is read synchronously and the
 *    product is registered -> 2 enabled cycles;
 *  - otherwise: the sample and an asynchronous ROM read feed the product
 *    register directly -> 1 enabled cycle.
 *
 * Flow control is pass-through: `io.d.ready` mirrors `io.q.ready` and
 * `io.q.valid` mirrors the input fire, so every register is clock-enabled by
 * that same `fire`.
 *
 * @tparam T streamed method type carried by `p`
 * @param dW target width handed to `cut` for the product
 * @param p  WPB decomposition; supplies `N` and the `store_tab` ROM ordering
 */
sealed class freqCorr[T <: Streamed](val dW: Int, val p: wpbFFT[T]) extends streamedComponents with samewidthed {
    val io = iogen

    /** 2 selects the registered (two-stage) data path, 1 the single-stage one. */
    def method = if(p.N>5) 2 else 1

    /** Number of `fire`-enabled cycles between `io.d` and `io.q`. */
    def delay = if(method>1) 2 else 1

    val tW = 1-getTwiddleResolution()
    val oc = circularOrder(p.N)
    val (wf,ws) = oc.wfft(tW)

    // Back-pressure and valid are forwarded unchanged; nothing is buffered here.
    io.d.ready := io.q.ready
    val fire = io.d.fire
    io.q.valid := fire

    // ROM address: one step per accepted sample, restarted by the frame's last sample.
    val cnt = Counter(p.N-1,fire)
    when(io.d.payload.last & fire) {
        cnt.clear()
    }

    // Coefficient ROM, one literal per index 1 .. p.N-1 (stored at address index-1).
    val mem = Mem(FixComplex(tW), initialContent = (1 until p.N).map { i =>
        val coeff = wf(p.p.first.store_tab(i-1))
        val entry = FixComplex(tW)
        entry.re.fromInt(coeff._1)
        entry.im.fromInt(coeff._2)
        entry
    })

    // Data operand: registered only in the two-stage variant so that it lines up
    // with the synchronous ROM read below.
    val d = if(method>1) {
        RegNextWhen(io.d.payload.fragment,fire)
    } else {
        io.d.payload.fragment
    }

    // Coefficient operand: synchronous read (two-stage) or asynchronous read (single-stage).
    val f = if(method>1) {
        mem.readSync(
            address = cnt.value,
            enable  = fire
        )
    } else {
        mem(cnt.value)
    }

    // Product, cut back to dW bits, then registered once in either variant.
    val df = (d*f).cut(dW,-getTwiddleResolution(),d.resolution)
    val q = if(method>1) {
        RegNextWhen(df,fire)
    } else {
        Delay(df,delay,fire)
    }

    // `last` travels through the same number of enabled stages as the data.
    io.q.payload.last     := Delay(io.d.payload.last,delay,fire)
    io.q.payload.fragment := q
}

/**
 * Streamed combiner for a WPB (prime-length) decomposition.
 *
 * Data path as wired below:
 *  - every accepted sample of a frame except the first one (`icnt =/= 0`) is
 *    forwarded to `f_block`; its output passes through a [[freqCorr]]
 *    coefficient multiplier and then through the second stage `s_block`;
 *  - in parallel, the first sample of the frame is captured in `d0` and the
 *    remaining samples are accumulated in `sum`;
 *  - the first output word of a frame is `sum + d0`; every other output word is
 *    `d0 + <rescaled s_block output>`.
 *
 * `io`, `f_block`, `sg`, `second`, `G`, `dW`, `oW` and `interW` are not defined
 * in this class; presumably they are inherited from `combineStreamedComponents`.
 *
 * @tparam T streamed method type
 * @param S  subtracted from log2(p.N-1) when sizing the internal growth and used as
 *           the number of extra bits carried by the `sum` accumulator
 * @param p  WPB decomposition; `p.N` is the frame length used by all counters
 */
class wpb[T<:Streamed](val S:Int,val p:wpbFFT[T]) extends combineStreamedComponents[T] {
    // Extra output bits granted to the second stage: round(log2(N-1) - S).
    val increase = Math.round(Math.log(p.N-1)/Math.log(2)-S).toInt
    val owW = oW+increase
    // Second stage, generated with interW input width and owW output width.
    val s_block = FFTIOConfig(interW,owW) on new Area {
        val u = sg(second,"s") 
        val delay = u.delay
    }
    
    val s_block_q = s_block.u.io.q 

    // f_block -> coefficient multiplier -> s_block
    val xcorr = new freqCorr[T](interW,p)
    val ws = xcorr.ws
    f_block.u.io.q >> xcorr.io.d
    xcorr.io.q >> s_block.u.io.d
    
    // Pass-through flow control: the whole pipeline advances on the input fire.
    io.d.ready := io.q.ready
    val fire = io.d.fire
    io.q.valid := fire
    s_block_q.ready := io.q.ready

    // Position of the current input sample inside its frame (0 .. p.N-1).
    val icnt = Counter(p.N,fire)
    when(io.d.payload.last & fire) {
        icnt.clear
    }
    // The first sample of each frame is not sent to f_block.
    // Built as a fresh expression rather than by re-assigning the Bool returned by
    // io.d.fire: on SpinalHDL releases where Stream.fire is cached per stream that
    // Bool is the same signal as `fire`, and forcing it low at icnt === 0 would
    // stop icnt from ever advancing.
    val valid = io.d.fire & (icnt.value =/= 0)
    
    f_block.u.io.d.valid := valid
    f_block.u.io.d.payload.last     := io.d.payload.last
    f_block.u.io.d.payload.fragment := io.d.payload.fragment

    val sum   = Reg(VecComplex(G,owW+S))   // running total of samples 1 .. N-1 of the current frame
    val sum_d = Reg(VecComplex(G,owW))     // total of the previous frame, cut to owW
    val d0    = Reg(VecComplex(G,dW))      // first sample of the current frame
    val d0_d  = Reg(VecComplex(G,owW))     // first sample of the previous frame, rescaled to owW
    
    when(fire) {
        when(icnt===0) {
            // Frame boundary: hand the finished frame's values to the *_d registers
            // (the right-hand sides read the register values from before this edge)
            // and restart the accumulator. The first sample goes to d0, not to sum.
            sum   := VecComplex(G,owW+S).zero
            sum_d := sum.cut(owW,S)
            d0    := io.d.payload.fragment
            d0_d  := d0.resize(owW+S).cut(owW,S)
        }.otherwise {
            sum   := sum+io.d.payload.fragment.resize(owW+S)
        }
    }
    // Latency of the f_block -> xcorr -> s_block path in samples it processes.
    val sub_delay = xcorr.delay + f_block.delay + s_block.delay
    // NOTE(review): the extra sub_delay/(p.N-1) presumably accounts for the one
    // input cycle per frame in which nothing enters f_block - confirm.
    val fdelay = sub_delay + sub_delay/(p.N-1)
    // Number of additional frame-boundary steps d0_d / sum_d are held so they meet
    // the matching s_block output (presumably whole frames of latency - confirm).
    val fdelay_n = (fdelay-1)/p.N-1
    
    val d0_dd  = if(fdelay_n>0) Delay(d0_d,fdelay_n,(icnt.value===0 & fire)) else d0_d
    val sum_dd = if(fdelay_n>0) Delay(sum_d,fdelay_n,(icnt.value===0 & fire)) else sum_d

    // Output-side frame position: advances on the input fire and is re-aligned by
    // the `last` flag coming out of s_block.
    val ocnt = Counter(p.N,fire)
    when(s_block_q.payload.last & s_block_q.fire ) {
        ocnt.clear
    }
    val delay_one = RegInit(False)          // selects the one-sample-delayed s_block result
    val d0_s      = Reg(VecComplex(G,owW))  // first-sample term added to every s_block output of the frame
    val sq        = VecComplex(G,oW)        // d0_s + s_block output, combinational
    val sq_d      = Reg(VecComplex(G,oW))   // sq held from the last s_block fire
    val q         = Reg(VecComplex(G,oW))   // output register
    val q0        = Reg(VecComplex(G,oW))   // first output word of the next frame (sum + d0)
    val od        = VecComplex(G,owW)
    // Rescale the s_block output by ws (the second value returned by wfft in freqCorr).
    od  := s_block_q.payload.fragment.resize(owW+ws).cut(owW,ws)
    sq := (d0_s + od).fixTo(oW,0)
    when(fire) {
        // At the end of an output frame, latch the values needed for the next one.
        when(ocnt===p.N-1 | (s_block_q.payload.last & s_block_q.fire)) {
            q0        := (sum_dd+d0_dd).fixTo(oW,0) 
            d0_s      := d0_dd
        }
        // delay_one is cleared at the input frame start and set at the output
        // frame start; the input-side condition has priority.
        when(icnt===0) {
            delay_one := False
        }.elsewhen(ocnt===0) {
            delay_one := True
        }
        when(s_block.u.io.q.fire) {
            sq_d := sq
        }
        // Output mux: first word of a frame is q0, the others come from s_block.
        when(ocnt===0) {
            q := q0
        }.otherwise { 
            when(delay_one) {
                q := sq_d
            }.otherwise{ 
                q := sq
            }
        }
    }
    /** Total latency of this component in `fire`-enabled cycles. */
    def delay = fdelay+2 - (if(fdelay%p.N==0) 1 else 0)
    io.q.payload.fragment := q
    // NOTE(review): misc.delay(...).delayPulse presumably re-times the input `last`
    // pulse by `delay` enabled cycles within a p.N-long frame - confirm against ffte.misc.
    io.q.payload.last     := misc.delay(delay,p.N).delayPulse(io.d.payload.last,fire)
}